
improve custom_jvp/vjp error messages #12611

Merged

Conversation

@mattjj (Collaborator) commented Oct 1, 2022

In particular:

  • add function names so it's clear which decorated functions and rules are causing the error;
  • when possible (i.e. when the functions have actually been run), check for agreement of pytree structure and leaf shapes/dtypes between the primal function and its rules.
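To make the contract being checked concrete, here is a minimal well-formed `custom_vjp` definition (a sketch with hypothetical names `f`/`f_fwd`/`f_bwd`, not code from this PR): the fwd rule must return a pair whose first element matches the decorated function's output structure, and the bwd rule returns one cotangent per primal argument.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Must return a pair: (primal output matching f's output exactly,
    # residuals saved for the backward pass).
    return jnp.sin(x), jnp.cos(x)

def f_bwd(residuals, cotangent):
    cos_x = residuals
    # One cotangent per primal argument, returned as a tuple.
    return (cotangent * cos_x,)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.0))  # -> 1.0, i.e. cos(0.0)
```

The checks added in this PR verify exactly these properties of `f_fwd`'s return value, and the function names (`f`, `f_fwd`) are what now appear in the error text.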

The latter will make the error message from this issue much better! Instead of this:

# ERROR ON MAIN
... giant stack trace ...
  File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 196, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
  File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 396, in eval_jaxpr
    map(write, eqn.outvars, ans)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/util.py", line 48, in safe_map
    assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
AssertionError: length mismatch: [3, 1]

we get this:

# ERROR AFTER THIS PR
TypeError: Custom VJP fwd rule flash_attention_forward for function
flash_attention must produce a pair (list or tuple of length two) where the
first element represents the primal output (equal to the output of the
custom_vjp-decorated function flash_attention) and the second element
represents residuals (i.e. values stored from the forward pass for use on the
backward pass), but instead the fwd rule output's first element had
container/pytree structure:
    float32[3,16,5,19]
while the custom_vjp-decorated function flash_attention had output
container/pytree structure:
    (float32[3,16,5,19], (float32[3,16,5], float32[3,16,5])).
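On a JAX version that includes this change, the new check can be reproduced with a small sketch (hypothetical names `g`/`g_fwd`/`g_bwd`, not the flash-attention code from the report): a fwd rule that returns a bare array instead of the required `(primal_output, residuals)` pair now raises a `TypeError` naming both the decorated function and the offending rule.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def g(x):
    return jnp.sin(x)

def g_fwd(x):
    # BUG: returns a bare array instead of the required
    # (primal_output, residuals) pair.
    return jnp.sin(x)

def g_bwd(residuals, cotangent):
    return (cotangent * jnp.cos(residuals),)

g.defvjp(g_fwd, g_bwd)

err = None
try:
    jax.grad(g)(1.0)
except TypeError as e:
    err = e

# The error message names both g and g_fwd, so the offending
# rule is easy to locate.
print(type(err).__name__)  # -> TypeError
```

Fixing `g_fwd` to `return jnp.sin(x), x` (a proper pair) makes the gradient go through.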

@mattjj mattjj requested a review from froystig October 1, 2022 04:45
@mattjj mattjj force-pushed the custom-vjp-improve-type-error-checking branch from 955502d to 430f3d9 October 1, 2022 04:48
@mattjj mattjj added the better_errors (Improve the error reporting) label Oct 1, 2022
@mattjj mattjj force-pushed the custom-vjp-improve-type-error-checking branch from 430f3d9 to 903326f October 1, 2022 05:13
context: lucidrains/flash-attention-jax#7
@mattjj mattjj force-pushed the custom-vjp-improve-type-error-checking branch from 903326f to b8c87bc October 1, 2022 05:41
@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Oct 2, 2022
@copybara-service copybara-service bot merged commit 6318fdc into jax-ml:main Oct 3, 2022
@mattjj mattjj deleted the custom-vjp-improve-type-error-checking branch October 3, 2022 22:27
copybara-service bot pushed a commit that referenced this pull request Oct 18, 2022
PiperOrigin-RevId: 481984707
blurgyy added a commit to blurgyy/jaxngp that referenced this pull request Jun 13, 2023
The commit that (possibly) introduced the regression: <jax-ml/jax#12611>

It does not yet affect the current jaxngp codebase (because our JAX
version does not contain the commit that introduced this regression),
but it would affect jaxngp if the nixpkgs flake were bumped to the
23.05 tag, where the commit has been packaged into nixpkgs while the
revert <jax-ml/jax#12852> has not.

This commit therefore serves as future-proofing, and also makes the
custom differentiation implementation consistent between the two custom
CUDA extensions `jax-tcnn` and `volume-rendering-jax`.

Signed-off-by: Gaoyang Zhang <gy@blurgy.xyz>